import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.data as torch_data_util
import math

class CosCompare(nn.Module):
    """
    Contrastive-style loss computed on the cosine similarity of two embeddings.

    Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
    Unlike the paper, which uses Euclidean distance, this module uses cosine
    *similarity* (higher means more alike). For each compared pair the loss is

        (1 - label) * sim**2 + label * max(margin - sim, 0)**2

    so pairs with ``label == 0`` are pushed toward similarity 0, and pairs with
    ``label == 1`` are pushed toward similarity >= ``margin``. The returned
    value is the mean over all elements.

    Args:
        margin: Similarity threshold that ``label == 1`` pairs should reach.
        dim: Dimension along which cosine similarity is computed
            (default 1, same as ``F.cosine_similarity``).
        eps: Small value guarding against division by zero
            (default 1e-8, same as ``F.cosine_similarity``).
    """

    def __init__(self, margin=0.0, dim=1, eps=1e-8):
        super(CosCompare, self).__init__()
        self.margin = margin
        # forward() passes these straight to F.cosine_similarity, so they must
        # exist on the instance.
        self.dim = dim
        self.eps = eps

    def forward(self, output1, output2, label):
        """
        Compute the mean loss for a batch of embedding pairs.

        Args:
            output1, output2: Embedding tensors compared along ``self.dim``.
            label: Tensor broadcastable against the similarity tensor;
                presumably 0/1 pair labels — confirm with callers.

        Returns:
            A scalar tensor holding the mean loss.
        """
        cos_sim = F.cosine_similarity(output1, output2, self.dim, self.eps)
        # clamp(min=0) implements the hinge max(margin - sim, 0).
        hinge = torch.clamp(self.margin - cos_sim, min=0.0)
        loss_contrastive = torch.mean(
            (1 - label) * torch.pow(cos_sim, 2)
            + label * torch.pow(hinge, 2)
        )
        return loss_contrastive